/*	--*- c -*--
 * Copyright (C) 2016 Enrico Scholz <enrico.scholz@sigma-chemnitz.de>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#define IS_DEBUG_MODE	1
#define DEBUG		1

#include <arm_neon.h>
#include <string.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

#include "bayer2rgb.h"
#include "bayer2rgb-internal.h"

#include "compiler.h"

/* Widening element-wise sum of two 8-lane rows (vertical neighbours).
 * Result stays 16-bit so callers can divide/round later without overflow. */
static inline uint16x8_t	merge2v(uint8x8_t prev, uint8x8_t next)
{
	uint16x8_t	sum = vaddl_u8(prev, next);

	return sum;
}

/* Rounded average of two vertical neighbours: (prev + next + 1) / 2,
 * narrowed back to 8 bit with saturation. */
static inline uint8x8_t		merge2v_r(uint8x8_t prev, uint8x8_t next)
{
	uint16x8_t	sum = merge2v(prev, next);

	return vqrshrn_n_u16(sum, 1);
}

/*

   v     = [ c0, c1, c2, c3, c4, c5, c6, c7 ]
   shft  = [ c1, c2, c3, c4, c5, c6, c7, c0 ]
   sum0  = [ (c0+c1), (c2+c3), (c4+c5), (c6+c7) ]
   sum1  = [ (c1+c2), (c3+c4), (c5+c6), (c7+c0) ]
   sumh2 = [ [ (c0+c1), (c2+c3), (c4+c5), (c6+c7) ],
             [ (c1+c2), (c3+c4), (c5+c6), (c7+c0) ] ],
   sum   = [ (c0+c1), (c1+c2), (c2+c3), (c3+c4),
             (c4+c5), (c5+c6), (c6+c7), (c7+c0) ]
 */
/*
   Sum every element with its right neighbour (wrapping at lane 7):

   v      = [ c0, c1, c2, c3, c4, c5, c6, c7 ]
   return = [ (c0+c1), (c1+c2), (c2+c3), (c3+c4),
              (c4+c5), (c5+c6), (c6+c7), (c7+c0) ]
 */
static inline uint16x8_t _always_inline_	merge2h(uint8x8_t v)
{
	uint8x8_t	rot      = vext_u8(v, v, 1);	/* [ c1 .. c7, c0 ] */
	uint16x4_t	even_sum = vpaddl_u8(v);	/* c0+c1, c2+c3, ... */
	uint16x4_t	odd_sum  = vpaddl_u8(rot);	/* c1+c2, c3+c4, ... */
	uint16x4x2_t	zipped   = vzip_u16(even_sum, odd_sum);

	return vcombine_u16(zipped.val[0], zipped.val[1]);
}

/*
   v    = [ c0, c1, c2, c3, c4, c5, c6, c7 ]
   return [ (c0+c1)/2, (c1+c2)/2, (c2+c3)/2, (c3+c4)/2,
            (c4+c5)/2, (c5+c6)/2, (c6+c7)/2, (c7+c0)/2 ]
*/
/*
   Rounded horizontal neighbour average:
   v    = [ c0, c1, c2, c3, c4, c5, c6, c7 ]
   return [ (c0+c1)/2, (c1+c2)/2, ..., (c6+c7)/2, (c7+c0)/2 ]
*/
static inline uint8x8_t _always_inline_		merge2h_r(uint8x8_t v)
{
	uint16x8_t	sums = merge2h(v);

	return vqrshrn_n_u16(sums, 1);
}

/* diamond merge */
/* Diamond merge: horizontal neighbour sums of 'hor' plus an already
 * computed vertical sum ('sumv').  Result is the 4-sample diamond total. */
static inline uint16x8_t _always_inline_	merge4d(uint8x8_t hor,
							uint16x8_t sumv)
{
	return vaddq_u16(merge2h(hor), sumv);
}

/* Diamond merge computing the vertical sum itself from the previous and
 * next rows. */
static inline uint16x8_t _always_inline_
merge4dv(uint8x8_t hor, uint8x8_t prev, uint8x8_t next)
{
	return merge4d(hor, vaddl_u8(prev, next));
}

/* rounded average of diamond merge */
/* Rounded average of the diamond merge (sum of 4 samples, /4). */
static inline uint8x8_t  _always_inline_
merge4d_r(uint8x8_t hor, uint16x8_t sumv)
{
	uint16x8_t	total = merge4d(hor, sumv);

	return vqrshrn_n_u16(total, 2);
}


/* Rounded diamond average built directly from the three source rows. */
static inline uint8x8_t  _always_inline_
merge4dv_r(uint8x8_t hor, uint8x8_t prev, uint8x8_t next)
{
	uint16x8_t	total = merge4dv(hor, prev, next);

	return vqrshrn_n_u16(total, 2);
}


/* quadrat merge */
/* Square (2x2) merge: horizontal pair sums of both rows, added together. */
static inline uint16x8_t _always_inline_
merge4q(uint8x8_t prev, uint8x8_t next)
{
	return vaddq_u16(merge2h(prev), merge2h(next));
}

/* rounded quadrat merge */
/* Rounded average of the 2x2 square merge (sum of 4 samples, /4). */
static inline uint8x8_t _always_inline_
merge4q_r(uint8x8_t prev, uint8x8_t next)
{
	uint16x8_t	total = merge4q(prev, next);

	return vqrshrn_n_u16(total, 2);
}

/* Rotate the vector left by 'cnt' bytes.  vext_u8 requires a compile-time
 * constant index; _always_inline_ plus constant arguments at every call
 * site make 'cnt' an immediate after inlining. */
static inline uint8x8_t _always_inline_ pixel_shift(uint8x8_t v, int cnt)
{
	uint8x8_t	rotated = vext_u8(v, v, cnt);

	return rotated;
}

/* Interleave two 8-lane vectors into one 16-lane vector:
 * result = [ e0, o0, e1, o1, ..., e7, o7 ] */
static inline uint8x16_t _always_inline_
combine_pixels(uint8x8_t even, uint8x8_t odd)
{
	uint8x8x2_t	zipped = vzip_u8(even, odd);

	return vcombine_u8(zipped.val[0], zipped.val[1]);
}

/* Demosaic 16 output pixels of a row whose Bayer pattern starts on a green
 * sample ("gX": green first, then colour c1).  The missing colour c0 is
 * sampled from the previous/next rows; c1 lives on the current row.
 * offs_c0/offs_g/offs_c1 select the byte lane (R/G/B position) within each
 * 4-byte output pixel; the unused fourth lane stays zero via 'row0 = { }'.
 * Writes 16 interleaved 4-byte pixels to row_ptr. */
static inline void  _always_inline_
convert_pixel_gX8_rgb32(void * __restrict__ row_ptr,
			unsigned int offs_c0, unsigned int offs_g,
			unsigned int offs_c1,
			uint8x8_t g_prev, uint8x8_t c0_prev,
			uint8x8_t c1,     uint8x8_t g_c1,
			uint8x8_t g_next, uint8x8_t c0_next)
{
	uint8x16x4_t	row0 = { };

	/* green: even output lanes use the native samples; odd lanes take
	 * the rounded diamond average of the four surrounding greens
	 * (shifted rows supply the diagonal neighbours) */
	row0.val[offs_g]  = combine_pixels(g_c1,
					   merge4dv_r(g_c1,
						      pixel_shift(g_prev, 1),
						      pixel_shift(g_next, 1)));

	/* c0: vertical average for even lanes, 2x2 square average for odd */
	row0.val[offs_c0] = combine_pixels(merge2v_r(c0_prev, c0_next),
					   merge4q_r(c0_prev, c0_next));

	/* c1: horizontal average for even lanes, native (shifted) sample
	 * for odd lanes */
	row0.val[offs_c1] = combine_pixels(merge2h_r(c1),
					   pixel_shift(c1, 1));

	/* interleaved store: 16 pixels x 4 bytes */
	vst4q_u8(row_ptr, row0);
}

/* Demosaic 16 output pixels of a row whose Bayer pattern starts on a
 * non-green sample ("Xg": colour c1 first, then green) — the lane layout
 * is the mirror of convert_pixel_gX8_rgb32.  c0 is sampled from the
 * previous/next rows; c1 lives on the current row.  offs_* select the
 * byte lane within each 4-byte output pixel; the fourth lane stays zero.
 * Writes 16 interleaved 4-byte pixels to row_ptr. */
static inline void  _always_inline_
convert_pixel_Xg8_rgb32(void * __restrict__ row_ptr,
			unsigned int offs_c0, unsigned int offs_g,
			unsigned int offs_c1,
			uint8x8_t c0_prev, uint8x8_t g_prev,
			uint8x8_t g_c1,    uint8x8_t c1,
			uint8x8_t c0_next, uint8x8_t g_next)
{
	uint8x16x4_t	row0 = { };

	/* c0: 2x2 square average for even lanes, shifted vertical average
	 * for odd lanes */
	row0.val[offs_c0] = combine_pixels(merge4q_r(c0_prev, c0_next),
					   pixel_shift(merge2v_r(c0_prev,
								 c0_next), 1));

	/* green: diamond average for even lanes, native (shifted) sample
	 * for odd lanes */
	row0.val[offs_g]  = combine_pixels(merge4dv_r(g_c1, g_prev, g_next),
					   pixel_shift(g_c1, 1));

	/* c1: native sample for even lanes, horizontal average for odd */
	row0.val[offs_c1] = combine_pixels(c1, merge2h_r(c1));

	/* interleaved store: 16 pixels x 4 bytes */
	vst4q_u8(row_ptr, row0);
}

/* GB row pairing: green-first row where blue is the in-row colour and red
 * comes from the neighbouring rows.  Pure forwarder to the generic
 * green-first helper with c0 = red, c1 = blue. */
static inline void  _always_inline_
convert_pixel_gb8_rgb32(void * __restrict__ row_ptr,
			unsigned int offs_r, unsigned int offs_g,
			unsigned int offs_b,
			uint8x8_t g_prev, uint8x8_t r_prev,
			uint8x8_t b,      uint8x8_t g_b,
			uint8x8_t g_r,    uint8x8_t r)
{
	convert_pixel_gX8_rgb32(row_ptr,
				offs_r, offs_g, offs_b,
				g_prev, r_prev,
				b, g_b,
				g_r, r);
}

/* GR row pairing: green-first row where red is the in-row colour and blue
 * comes from the neighbouring rows.  Pure forwarder to the generic
 * green-first helper with c0 = blue, c1 = red (note swapped offsets). */
static inline void  _always_inline_
convert_pixel_gr8_rgb32(void * __restrict__ row_ptr,
			unsigned int offs_r, unsigned int offs_g,
			unsigned int offs_b,
			uint8x8_t g_prev, uint8x8_t b_prev,
			uint8x8_t r,      uint8x8_t g_r,
			uint8x8_t g_b,    uint8x8_t b)
{
	convert_pixel_gX8_rgb32(row_ptr,
				offs_b, offs_g, offs_r,
				g_prev, b_prev,
				r, g_r,
				g_b, b);
}

/* RG row pairing: red-first row where blue comes from the neighbouring
 * rows.  Pure forwarder to the generic colour-first helper with
 * c0 = blue, c1 = red (note swapped offsets). */
static inline void  _always_inline_
convert_pixel_rg8_rgb32(void * __restrict__ row_ptr,
			unsigned int offs_r, unsigned int offs_g,
			unsigned int offs_b,
			uint8x8_t b,      uint8x8_t g_b,
			uint8x8_t g_r,    uint8x8_t r,
			uint8x8_t b_next, uint8x8_t g_next)
{
	convert_pixel_Xg8_rgb32(row_ptr,
				offs_b, offs_g, offs_r,
				b, g_b,
				g_r, r,
				b_next, g_next);
}

/* BG row pairing: blue-first row where red comes from the neighbouring
 * rows.  Pure forwarder to the generic colour-first helper with
 * c0 = red, c1 = blue. */
static inline void  _always_inline_
convert_pixel_bg8_rgb32(void * __restrict__ row_ptr,
			unsigned int offs_r, unsigned int offs_g,
			unsigned int offs_b,
			uint8x8_t r_prev, uint8x8_t g_prev,
			uint8x8_t g_b,    uint8x8_t b,
			uint8x8_t r,      uint8x8_t g_r)
{
	convert_pixel_Xg8_rgb32(row_ptr,
				offs_r, offs_g, offs_b,
				r_prev, g_prev,
				g_b, b,
				r, g_r);
}

/* Prefetch (read, temporal locality _lvl) one cache line ahead of _addr.
 * NB: no semicolon after 'while (0)' — the caller supplies it, which keeps
 * the macro safe inside unbraced if/else bodies.  (The previous trailing
 * semicolon made 'if (x) do_prefetch(p, l); else ...' a syntax error.) */
#define do_prefetch(_addr, _lvl) do { \
		__builtin_prefetch(((void const *)(_addr)) + 64, 0, (_lvl)); \
	} while (0)

static void run_tests(void);

/* NEON entry point: convert a Bayer-pattern input image to a 32-bpp RGB
 * variant.  Dispatches on output format/bpp to the shared conversion body
 * in convert-neon-infmt.inc.h (included per format with the matching
 * rgb_pixel_t typedef; the included code presumably sets 'handled' —
 * TODO confirm).  Falls back to the optimized C implementation when the
 * format is unsupported or the NEON path did not handle it. */
void bayer2rgb_convert_neon(struct image_in const * __restrict__ input,
			    struct image_out const * __restrict__ output,
			    struct image_conversion_info *info)
{
	/* 'border' / 'rows_per_loop' are consumed by the included
	 * conversion body */
	static unsigned int const	border = 2;
	static unsigned int const	rows_per_loop = 2;
	bool				handled = false;

	/* NOTE(review): unconditional self-test call on every conversion;
	 * looks like debug leftover tied to the IS_DEBUG_MODE/DEBUG defines
	 * at the top of the file — confirm before shipping */
	run_tests();

	/* image too small for the 2-pixel border the algorithm needs */
	if (input->info.h <= 2*border ||
	    input->info.w <= 2*border)
		return;

	switch (output->type) {
	case RGB_FMT_RGBx:
		switch (output->info.bpp) {
		case 32: {
			/* rgb_pixel_t selects the channel layout used by
			 * the included conversion body */
			typedef struct rgbx32_pixel	rgb_pixel_t;
#include "convert-neon-infmt.inc.h"
			break;
		}

		default:
			set_fallback_reason(info, "neon: RGBx bpp");
			break;
		}
		break;

	case RGB_FMT_BGRx:
		switch (output->info.bpp) {
		case 32: {
			typedef struct bgrx32_pixel	rgb_pixel_t;
#include "convert-neon-infmt.inc.h"
			break;
		}

		default:
			set_fallback_reason(info, "neon: BGRx bpp");
			break;
		}
		break;

	case RGB_FMT_xBGR:
		switch (output->info.bpp) {
		case 32: {
			typedef struct xbgr32_pixel	rgb_pixel_t;
#include "convert-neon-infmt.inc.h"
			break;
		}

		default:
			set_fallback_reason(info, "neon: xBGR bpp");
			break;
		}
		break;

	case RGB_FMT_xRGB:
		switch (output->info.bpp) {
		case 32: {
			typedef struct xrgb32_pixel	rgb_pixel_t;
#include "convert-neon-infmt.inc.h"
			break;
		}

		default:
			set_fallback_reason(info, "neon: xRGB bpp");
			break;
		}
		break;

	default:
		set_fallback_reason(info, "neon: unsupported output format");

		/* TODO: handle other output formats */
		break;
	}

	/* NEON path declined (format/bpp) — use the portable C fallback */
	if (!handled)
		bayer2rgb_convert_c_opt(input, output, info);
}

#include "convert-neon_tests.inc.h"
